﻿/*******************************************************************************
* Modifications Copyright (c) 2021 Advanced Micro Devices, Inc. All rights reserved.
* Notified per clause 4(b) of the license.
*******************************************************************************/

/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include <assert.h>
#include <float.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/zendnn_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/gemm/gemm.hpp"

#include "cpu/matmul/zendnn_f32_matmul.hpp"
#include "cpu/matmul/matmul_utils.hpp"

#include "zendnn_logging.hpp"
#include "zendnn_private.hpp"

namespace zendnn {
namespace impl {
namespace cpu {
namespace matmul {

using namespace data_type;

// Initializes the primitive descriptor: verifies that data types, bias,
// attributes and memory formats are ones this implementation accepts, records
// the derived GEMM parameters in params_, and finally delegates attribute
// validation to check_and_configure_attributes().
//
// Returns status::unimplemented as soon as one of the checks fails; otherwise
// returns whatever check_and_configure_attributes() returns.
status_t zendnn_f32_matmul_t::pd_t::init(engine_t *engine) {
    zendnnInfo(ZENDNN_CORELOG, "zendnn_f32_matmul_t::pd_t::init()");

    // Source, weights, accumulator and destination must all use the data
    // types this implementation is declared with.
    const bool data_types_match = src_md()->data_type == src_type
            && weights_md()->data_type == weights_type
            && desc()->accum_data_type == acc_type
            && dst_md()->data_type == dst_type;
    if (!data_types_match) return status::unimplemented;

    // A bias, when present, has to be f32 and shaped 1xN.
    if (with_bias()
            && (weights_md(1)->data_type != f32 || !is_bias_1xN()))
        return status::unimplemented;

    // Only runtime output scales and post-ops may deviate from the defaults.
    const auto tolerated_attrs = primitive_attr_t::skip_mask_t::oscale_runtime
            | primitive_attr_t::skip_mask_t::post_ops;
    if (!attr()->has_default_values(tolerated_attrs))
        return status::unimplemented;

    // Format resolution runs only after every cheaper check has passed, in
    // the same order as the checks above.
    if (!set_default_formats()) return status::unimplemented;
    if (!gemm_based::check_gemm_compatible_formats(*this))
        return status::unimplemented;

    // set state
    params_.dst_is_acc_ = true;
    if (!has_runtime_dims_or_strides()) {
        matmul_helper_t helper(src_md(), weights_md(), dst_md());
        params_.can_fuse_src_batch_dims_ = helper.can_fuse_src_batch_dims();
    }

    return check_and_configure_attributes();
}

// Validates the output-scale and post-op attributes and derives the GEMM /
// post-processing state kept in params_ (pp_attr_, gemm_applies_output_scales_,
// gemm_beta_, has_pp_kernel_).
//
// Returns status::unimplemented when the attributes describe something the
// execution path cannot honour, status::success otherwise. A failure of
// pp_attr_.copy_from() is forwarded through CHECK.
//
// When ZENDNN_ENABLE is set, execute_ref() dispatches to the ZenDNN kernels
// (zenMatMul, zenMatMulWithBias, zenMatMulWithBiasReLU). Those calls receive a
// single scalar alpha and no eltwise description, and no post-processing
// kernel is run afterwards. The checks below therefore reject every
// configuration that would otherwise be accepted here and then silently
// dropped or mis-applied at execution time.
status_t zendnn_f32_matmul_t::pd_t::check_and_configure_attributes() {
    zendnnInfo(ZENDNN_CORELOG,
               "zendnn_f32_matmul_t::pd_t::check_and_configure_attributes");

    auto check_attr_oscale = [&]() -> bool {
        const auto &oscale = attr()->output_scales_;
#if ZENDNN_ENABLE
        // Only a common (mask 0) scale can be folded into the scalar alpha
        // handed to the ZenDNN kernels; per-N scales would never be applied.
        return oscale.mask_ == 0;
#else
        // Common scale, or per-N scale for the non-batched case.
        return oscale.mask_ == 0 || (oscale.mask_ == (1 << 1) && !batched());
#endif
    };

    auto check_attr_post_ops = [&]() -> bool {
        using namespace primitive_kind;
        const auto &p = attr()->post_ops_;

        // sum is executed through the GEMM beta, which is only valid when the
        // GEMM also applies the output scales.
        auto check_sum = [&](int idx) -> bool {
            return p.contain(sum, idx) && params_.gemm_applies_output_scales_;
        };

        auto check_eltwise = [&](int idx) -> bool {
            if (!p.contain(eltwise, idx)) return false;
#if ZENDNN_ENABLE
            // The only eltwise-capable call in execute_ref() is
            // zenMatMulWithBiasReLU: it needs a bias and is given no eltwise
            // algorithm, scale or slope. Accept just a plain ReLU with bias.
            // NOTE(review): assumes zenMatMulWithBiasReLU computes max(x, 0)
            // - confirm against the kernel.
            return with_bias() && p.entry_[idx].is_relu();
#else
            return true;
#endif
        };

        switch (p.len()) {
            case 0: return true;
            case 1: return check_sum(0) || check_eltwise(0);
            case 2: return check_sum(0) && check_eltwise(1);
            default: return false;
        }
    };

    // check basic attributes
    if (!check_attr_oscale()) return status::unimplemented;

    // set state
    CHECK(params_.pp_attr_.copy_from(*attr()));
#if ZENDNN_ENABLE
    // ZenDNN MatMul can support scalar alpha values with bias
    params_.gemm_applies_output_scales_ = attr()->output_scales_.mask_ == 0;
#else
    params_.gemm_applies_output_scales_
            = attr()->output_scales_.mask_ == 0 && !with_bias();
#endif
    // Scales applied by the GEMM must not be applied a second time by the
    // post-processing attributes.
    if (params_.gemm_applies_output_scales_)
        params_.pp_attr_.output_scales_.set(1.f);

    // check post-ops (must run after gemm_applies_output_scales_ is set,
    // because check_sum reads it)
    if (!check_attr_post_ops()) return status::unimplemented;

    auto &po = params_.pp_attr_.post_ops_;
    const int sum_idx = 0;
    if (po.len() > 0 && po.contain(primitive_kind::sum, sum_idx)) {
        // set state
        params_.gemm_beta_ = po.entry_[sum_idx].sum.scale;
        // drop sum from pp_attributes, as it will be applied by gemm
        po.entry_.erase(po.entry_.begin());
    }

    // set state
    params_.has_pp_kernel_
            = with_bias() || !params_.pp_attr_.has_default_values();

    return status::success;
}

// Executes the f32 MatMul described by pd() on the memories found in ctx.
//
// Problem sizes, transposition flags and leading dimensions are derived from
// the memory descriptors. With ZENDNN_ENABLE the work is handed to one of the
// ZenDNN kernels, chosen by bias presence and by whether an eltwise post-op is
// attached. Otherwise extended_sgemm() is used, followed by the
// post-processing kernel when params.has_pp_kernel_ is set.
//
// Returns status::success, or (non-ZenDNN path) the first failing status
// reported by extended_sgemm().
status_t zendnn_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, ZENDNN_ARG_SRC);
    auto weights = CTX_IN_MEM(const weights_data_t *, ZENDNN_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, ZENDNN_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, ZENDNN_ARG_DST);

    DEFINE_SCALES_BUFFER(scales);

    const auto src_d = ctx.memory_mdw(ZENDNN_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(ZENDNN_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(ZENDNN_ARG_DST, pd()->dst_md());

    const gemm_based::params_t &params = pd()->params();

    // For a batched problem dimension 0 is the batch, so the matrix
    // dimensions start one position later.
    const bool batched = pd()->batched();
    const int row_dim = batched ? 1 : 0;

    const dim_t M = dst_d.dims()[row_dim];
    const dim_t N = dst_d.dims()[row_dim + 1];
    const dim_t K = src_d.dims()[row_dim + 1];

    // Strides of the two matrix dimensions (index 0: rows, index 1: columns).
    const auto *src_strides = &src_d.blocking_desc().strides[row_dim];
    const auto *weights_strides = &weights_d.blocking_desc().strides[row_dim];

    // "N" when the innermost stride is 1 and the matrix has more than one
    // row, "T" otherwise.
    const char *transA
            = src_strides[1] == 1 && src_d.dims()[row_dim] > 1 ? "N" : "T";
    const char *transB
            = weights_strides[1] == 1 && weights_d.dims()[row_dim] > 1 ? "N"
                                                                       : "T";

    const int M_s32 = static_cast<int>(M);
    const int N_s32 = static_cast<int>(N);
    const int K_s32 = static_cast<int>(K);

    // Leading dimension: row stride for "N", column stride for "T".
    const int lda = static_cast<int>(src_strides[*transA == 'N' ? 0 : 1]);
    const int ldb = static_cast<int>(weights_strides[*transB == 'N' ? 0 : 1]);
    const int ldc = static_cast<int>(dst_d.blocking_desc().strides[row_dim]);

    // Compute alpha once, per build flavour, so that the value logged below
    // is the value actually passed to the kernel.
#if ZENDNN_ENABLE
    const float alpha
            = pd()->attr()->output_scales_.mask_ == 0 ? scales[0] : 1.0f;
#else
    const float alpha = params.get_gemm_alpha(scales);
#endif
    const float beta = params.gemm_beta_;
    //TODO(aakar): Modify f32_matmul API to include Layout
    const bool Layout = true; // CblasRowMajor

    const dim_t batch = batched ? src_d.dims()[0] : 1;

    zendnnInfo(ZENDNN_CORELOG, "zendnn_f32_matmul_t::execute_ref");
    zendnnInfo(ZENDNN_CORELOG, "M: :",M, " N: ",N, " K: ", K);
    zendnnInfo(ZENDNN_CORELOG, "M_s32: :",M_s32, " N_s32: ",N_s32, " K_s32: ",
               K_s32);
    zendnnInfo(ZENDNN_CORELOG, " alpha: :", alpha, " beta: ", beta,
               " Layout: ", Layout ? "CblasRowMajor" : "CblasColMajor");

#if ZENDNN_ENABLE
    const bool has_eltwise
            = pd()->attr()->post_ops_.find(primitive_kind::eltwise) >= 0;

    // Transposition flags; same truth value as the former
    // strcmp(trans, "N") without depending on <string.h>.
    const bool trans_src = *transA != 'N';
    const bool trans_wei = *transB != 'N';

    // The ZenDNN entry points are invoked with mutable float pointers, as the
    // original (float *) casts did; keep that in one named-cast helper.
    auto as_f32 = [](const void *p) -> float * {
        return const_cast<float *>(static_cast<const float *>(p));
    };
    float *src_f32 = as_f32(src);
    float *weights_f32 = as_f32(weights);
    float *bias_f32 = as_f32(bias);
    float *dst_f32 = as_f32(dst);

    if (bias_f32 == nullptr) {
        //MatMul without Bias
        // NOTE(review): this call carries no post-op information, so an
        // eltwise post-op is not applied here; correct only if the descriptor
        // never pairs eltwise with a bias-less matmul.
        zenMatMul(Layout, trans_src, trans_wei, batch, M, K, N, alpha, src_f32,
                  lda, weights_f32, ldb, beta, dst_f32, ldc);
    }
    else if (!has_eltwise) {
        //MatMul with Bias
        zenMatMulWithBias(Layout, trans_src, trans_wei, batch, M, K, N, alpha,
                          src_f32, lda, weights_f32, ldb, bias_f32, beta,
                          dst_f32, ldc);
    }
    else {
        //MatMul with BiasRelu
        // NOTE(review): the eltwise algorithm is not inspected; every eltwise
        // kind is routed to the ReLU kernel.
        zenMatMulWithBiasReLU(Layout, trans_src, trans_wei, batch, M, K, N,
                              alpha, src_f32, lda, weights_f32, ldb, bias_f32,
                              beta, dst_f32, ldc);
    }

    return status::success;
#else //ZENDNN_ENABLE
    const auto src_batch_stride = src_d.blocking_desc().strides[0];
    const auto weights_batch_stride = weights_d.blocking_desc().strides[0];
    const auto dst_batch_stride = dst_d.blocking_desc().strides[0];

    // First failure reported by any thread; atomic because it is written
    // from inside the parallel region.
    std::atomic<status_t> st(status::success);

    if (batch > 1) {
        // Batches are distributed over the threads; each batch is one GEMM.
        parallel(0, [&](int ithr, int nthr) {
            size_t batch_start {}, batch_end {};
            balance211((size_t)(batch), nthr, ithr, batch_start, batch_end);
            for (size_t b = batch_start; b < batch_end; ++b) {
                const src_data_t *curr_src = src + b * src_batch_stride;
                const weights_data_t *curr_weights
                        = weights + b * weights_batch_stride;
                dst_data_t *curr_dst = dst + b * dst_batch_stride;

                // Operands are passed in swapped order (weights first, N
                // before M) - presumably because extended_sgemm is
                // column-major while the tensors are row-major.
                const status_t st_thr = extended_sgemm(transB, transA, &N_s32,
                        &M_s32, &K_s32, &alpha, curr_weights, &ldb, curr_src,
                        &lda, &beta, curr_dst, &ldc, nullptr, false);
                if (st_thr != status::success) {
                    // Do not post-process a destination the GEMM failed on.
                    st = st_thr;
                    return;
                }

                if (params.has_pp_kernel_) {
                    const float *pp_scales
                            = params.get_post_processing_scales(scales);
                    (*pp_kernel_)(curr_dst, curr_dst, bias, pp_scales, 0, M * N,
                                  (size_t)N);
                }
            }
        });
    }
    else {
        const status_t st_gemm = extended_sgemm(transB, transA, &N_s32, &M_s32,
                &K_s32, &alpha, weights, &ldb, src, &lda, &beta, dst, &ldc,
                nullptr, false);
        if (st_gemm != status::success) return st_gemm;

        if (params.has_pp_kernel_) {
            // Post-processing is split over the M * N output elements unless
            // the kernel asks to run sequentially.
            const bool force_sequential = pp_kernel_->sequential_kernel();
            const float *pp_scales = params.get_post_processing_scales(scales);
            parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
                size_t start {}, end {};
                balance211((size_t)(M * N), nthr, ithr, start, end);
                (*pp_kernel_)(dst, dst, bias, pp_scales, start, end, (size_t)N);
            });
        }
    }

    return st.load();
#endif //ZENDNN_ENABLE
}

} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace zendnn
